# Read datasets and models
library("DALEX")
library("randomForest")
library("patchwork")
library("ggplot2")
# Seed the RNG before fitting so the forest can be reproduced.
set.seed(1313)

# Random forest for `survived`, using the seven passenger attributes named in
# the formula.
titanic_rf <- randomForest(
  survived ~ class + gender + age + sibsp + parch + fare + embarked,
  data = titanic_imputed
)
# Single-row observation "henry". Every factor is created with its full set of
# levels listed explicitly; the level vectors live in a local scope so no
# extra objects are left in the workspace.
henry <- local({
  class_levels <- c(
    "1st", "2nd", "3rd", "deck crew",
    "engineering crew", "restaurant staff", "victualling crew"
  )
  gender_levels <- c("female", "male")
  port_levels <- c("Belfast", "Cherbourg", "Queenstown", "Southampton")

  data.frame(
    class = factor("1st", levels = class_levels),
    gender = factor("male", levels = gender_levels),
    age = 47,
    sibsp = 0,
    parch = 0,
    fare = 25,
    embarked = factor("Cherbourg", levels = port_levels)
  )
})
henry
# Single-row observation "johnny_d". Every factor is created with its full set
# of levels listed explicitly; the level vectors live in a local scope so no
# extra objects are left in the workspace.
johnny_d <- local({
  class_levels <- c(
    "1st", "2nd", "3rd", "deck crew",
    "engineering crew", "restaurant staff", "victualling crew"
  )
  gender_levels <- c("female", "male")
  port_levels <- c("Belfast", "Cherbourg", "Queenstown", "Southampton")

  data.frame(
    class = factor("1st", levels = class_levels),
    gender = factor("male", levels = gender_levels),
    age = 8,
    sibsp = 0,
    parch = 0,
    fare = 72,
    embarked = factor("Southampton", levels = port_levels)
  )
})
johnny_d
# Wrap the fitted forest in a DALEX explainer.
#
# The explainer data must not contain the target. The earlier
# `titanic_imputed[, -9]` removed nothing: the recorded run reported 8 columns,
# i.e. the whole data frame including `survived`. The target is therefore
# dropped by name, which does not depend on column positions.
titanic_rf_exp <- DALEX::explain(
  model = titanic_rf,
  data = titanic_imputed[, setdiff(colnames(titanic_imputed), "survived")],
  y = titanic_imputed$survived,
  label = "Random Forest"
)
# Console output of explain(), with terminal colour-escape residue removed.
# NOTE(review): the column count reflects dropping `survived`; all other
# figures are carried over from the original run - re-run to confirm.
## Preparation of a new explainer is initiated
## -> model label : Random Forest
## -> data : 2207 rows 7 cols
## -> target variable : 2207 values
## -> predict function : yhat.randomForest will be used ( default )
## -> predicted values : numerical, min = 0.01590278 , mean = 0.3222722 , max = 0.9900173
## -> model_info : package randomForest , ver. 4.6.14 , task regression ( default )
## -> residual function : difference between y and yhat ( default )
## -> residuals : numerical, min = -0.7970723 , mean = -0.0001153935 , max = 0.8992474
## A new explainer has been created!

# explain() reported the task as "regression" by default (see output above);
# override the recorded task type.
titanic_rf_exp$model_info$type <- "classification"
# Break-down attributions for johnny_d with an explicit variable order.
# keep_distributions = TRUE is passed so that the result can later be plotted
# with plot_distributions = TRUE.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked"),
  keep_distributions = TRUE
)
bd_rf
# Plot the break down plots
# (the heading above used to be a bare text line, which is not valid R)

# Standard break-down plot of the attributions stored in bd_rf.
plot(bd_rf)

# Same attributions drawn with plot_distributions = TRUE; bd_rf was computed
# with keep_distributions = TRUE.
plot(bd_rf, plot_distributions = TRUE)
# predict_parts() function
# (heading was fused onto the assignment below and misspelled "perdict_parts";
# it is now a separate comment so the statement parses)

# Break-down attributions for henry; no explicit variable order is supplied.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down"
)
bd_rf
# Plot the break down plots
# (the heading above used to be a bare text line, which is not valid R)

# Break-down plot of the attributions currently stored in bd_rf.
plot(bd_rf)
# predict_parts() function
# (heading was fused onto the assignment below; it is now a separate comment
# so the statement parses)

# Break-down attributions for henry with an explicit variable order.
bd_rf_order <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down",
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked")
)

# Plot with max_features = 3.
plot(bd_rf_order, max_features = 3)
# Break-down attributions for henry in a fixed variable order, computed with
# keep_distributions = TRUE so the plot below can use plot_distributions.
bd_rf_distr <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  keep_distributions = TRUE,
  order = c("class", "age", "gender", "fare", "parch", "sibsp", "embarked"),
  type = "break_down"
)

plot(bd_rf_distr, plot_distributions = TRUE)
# Attributions for johnny_d using type = "break_down_interactions";
# the result is printed and then plotted.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = johnny_d,
  type = "break_down_interactions"
)
bd_rf

plot(bd_rf)
# Attributions for henry using type = "break_down_interactions";
# the result is printed and then plotted.
bd_rf <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  type = "break_down_interactions"
)
bd_rf

plot(bd_rf)
# Ten break-down plots for johnny_d, each using a freshly drawn random
# ordering of the variables, arranged in a two-column grid.
set.seed(13)
rsample <- lapply(1:10, function(i) {
  # Random permutation of the seven variable positions, used as the order.
  new_order <- sample(1:7)
  bd <- predict_parts(titanic_rf_exp, johnny_d, order = new_order)

  # Shorten the long "embarked" entry in the variable column.
  bd$variable <- as.character(bd$variable)
  bd$variable[bd$variable == "embarked = Southampton"] <- "embarked = S"

  # Tag each result with its iteration number.
  bd$label <- paste("random order no.", i)

  # Same y-axis limits and breaks on every panel.
  plot(bd) +
    scale_y_continuous(limits = c(0.1, 0.6), name = "",
                       breaks = seq(0.1, 0.6, 0.1))
})

# Left fold with `+` is the same as rsample[[1]] + rsample[[2]] + ... written
# out by hand; plot_layout() then sets two columns.
Reduce(`+`, rsample) + plot_layout(ncol = 2)
# Attributions for johnny_d with type = "shap" and B = 25.
shap_johnny <- predict_parts(
  titanic_rf_exp,
  new_observation = johnny_d,
  type = "shap",
  B = 25
)
# Model prediction for henry (recorded console output follows).
predict(titanic_rf_exp, henry)
## 1
## 0.3081968

# Attributions for henry with type = "shap" and B = 25.
shap_henry <- predict_parts(
  explainer = titanic_rf_exp,
  new_observation = henry,
  B = 25,
  type = "shap"
)
shap_henry

# Plot the attributions twice: with the default settings, then with
# show_boxplots = FALSE.
plot(shap_henry)
plot(shap_henry, show_boxplots = FALSE)
# Record the R version, platform and package versions used for this run.
utils::sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Catalina 10.15.3
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] ggplot2_3.3.2 patchwork_1.0.1 randomForest_4.6-14
## [4] DALEX_1.3.1
##
## loaded via a namespace (and not attached):
## [1] pillar_1.4.6 compiler_4.0.2 tools_4.0.2 digest_0.6.25
## [5] jsonlite_1.7.0 evaluate_0.14 lifecycle_0.2.0 tibble_3.0.3
## [9] gtable_0.3.0 pkgconfig_2.0.3 png_0.1-7 rlang_0.4.7
## [13] yaml_2.2.1 xfun_0.15 withr_2.2.0 stringr_1.4.0
## [17] dplyr_1.0.0 knitr_1.29 generics_0.0.2 vctrs_0.3.2
## [21] grid_4.0.2 tidyselect_1.1.0 glue_1.4.1 R6_2.4.1
## [25] rmarkdown_2.3 iBreakDown_1.3.0 farver_2.0.3 purrr_0.3.4
## [29] magrittr_1.5 scales_1.1.1 ellipsis_0.3.1 htmltools_0.5.0
## [33] colorspace_1.4-1 labeling_0.3 stringi_1.4.6 munsell_0.5.0
## [37] crayon_1.3.4